--- ../../swvox/swvox/csrc/swvox_kernel.cu	2024-02-14 05:16:57.541407536 -0500
+++ ../google/swvox/swvox/csrc/swvox_kernel.cu	2024-02-13 10:46:24.584645005 -0500
@@ -46,14 +46,106 @@
        PackedTreeSpec<scalar_t>& __restrict__ tree,
        const scalar_t* __restrict__ xyz_ind,
        int64_t* node_id=nullptr) {
+
+    scalar_t _cube_sz;
+
     scalar_t xyz[3] = {xyz_ind[0], xyz_ind[1], xyz_ind[2]};
     transform_coord<scalar_t>(xyz, tree.offset, tree.scaling);
-    scalar_t _cube_sz;
+    
     return query_single_from_root<scalar_t>(data, tree.child,
             xyz, &_cube_sz, node_id);
 }
 
 template <typename scalar_t>
+// Transforms the query point xyz_ind with tree.offset / tree.scaling and
+// forwards it to query_path_from_root, which walks the tree from the root.
+// - tree_data: 5-D accessor handed unchanged to query_path_from_root.
+// - tree_max_depth: must not exceed MAX_TREE_DEPTH (scratch array capacity).
+// - data: per-level pointer array passed to query_path_from_root; callers
+//   pre-fill it with nullptr and read it back afterwards.
+// - node_id: optional output, forwarded to query_path_from_root.
+__device__ __inline__ void get_tree_path_ptr(
+        torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits>
+        tree_data,
+        PackedTreeSpec<scalar_t>& __restrict__ tree,
+        const scalar_t* __restrict__ xyz_ind,
+        const int64_t tree_max_depth,
+        scalar_t** data,
+        int64_t* node_id=nullptr) {
+    // The scratch arrays below only hold MAX_TREE_DEPTH + 1 levels; a deeper
+    // walk would write past them, so report the problem and leave `data` as is.
+    if (tree_max_depth > MAX_TREE_DEPTH) {
+        printf("Error: max_tree_depth exceeds MAX_TREE_DEPTH, add it as a parameter or make bigger.\n");
+        return;
+    }
+
+    // Per-level scratch outputs required by query_path_from_root.
+    scalar_t _cube_sz_ptr[MAX_TREE_DEPTH + 1];
+    scalar_t all_relative_pos[(MAX_TREE_DEPTH + 1) * 3];
+
+    scalar_t xyz[3] = {xyz_ind[0], xyz_ind[1], xyz_ind[2]};
+    transform_coord<scalar_t>(xyz, tree.offset, tree.scaling);
+    query_path_from_root(tree_data, tree.child, xyz, all_relative_pos,
+                         _cube_sz_ptr, data, tree_max_depth, node_id);
+}
+
+template <typename scalar_t>
+// Forward kernel; each thread that passes CUDA_GET_THREAD_ID handles one row
+// of `indices`. It resets node_ids_out[tid][*] to INVALID_NODE_ID, asks
+// get_tree_path_ptr for the per-level data pointers and copies the
+// tree.data.size(4) values of every level into values_out[tid][level][*].
+// Levels for which no pointer was produced are written as zeros.
+__global__ void query_path_single_kernel(
+        PackedTreeSpec<scalar_t> tree,
+        const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> indices,
+        torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> values_out,
+        torch::PackedTensorAccessor32<int64_t, 2, torch::RestrictPtrTraits> node_ids_out,
+        const int64_t tree_max_depth) {
+    CUDA_GET_THREAD_ID(tid, indices.size(0));
+    const int64_t n_levels = tree_max_depth + 1, K = tree.data.size(4);
+
+    // Clear every slot of the fixed-size array, independent of tree_max_depth,
+    // so the loop bound can never run past its MAX_TREE_DEPTH + 1 entries.
+    scalar_t* data_ptr[MAX_TREE_DEPTH + 1];
+    for (int i = 0; i < MAX_TREE_DEPTH + 1; ++i) data_ptr[i] = nullptr;
+    for (int i = 0; i < n_levels; ++i) node_ids_out[tid][i] = INVALID_NODE_ID;
+
+    get_tree_path_ptr(tree.data, tree, &indices[tid][0], tree_max_depth,
+                      data_ptr, &node_ids_out[tid][0]);
+
+    for (int64_t i = 0; i < n_levels; ++i) {
+        // Levels past the array or left at nullptr have no data: emit zeros.
+        const scalar_t* src = (i <= MAX_TREE_DEPTH) ? data_ptr[i] : nullptr;
+        for (int64_t j = 0; j < K; ++j) values_out[tid][i][j] = src ? src[j] : scalar_t(0);
+    }
+}
+
+template <typename scalar_t>
+// Backward kernel; each thread that passes CUDA_GET_THREAD_ID accumulates
+// grad_output[tid][level][*] into the grad_data_out entry of every path level.
+__global__ void query_path_single_kernel_backward(
+        PackedTreeSpec<scalar_t> tree,
+        const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> indices,
+        const torch::PackedTensorAccessor64<scalar_t, 3, torch::RestrictPtrTraits> grad_output,
+        torch::PackedTensorAccessor64<scalar_t, 5, torch::RestrictPtrTraits> grad_data_out,
+        const int64_t tree_max_depth) {
+    CUDA_GET_THREAD_ID(tid, indices.size(0));
+    scalar_t* data_ptr[MAX_TREE_DEPTH + 1];
+    for (int i = 0; i < MAX_TREE_DEPTH + 1; ++i) data_ptr[i] = nullptr;
+    get_tree_path_ptr(grad_data_out, tree, &indices[tid][0], tree_max_depth, data_ptr);
+    // Clamp the level count to data_ptr's capacity and grad_output's level axis.
+    int64_t n_levels = (tree_max_depth < MAX_TREE_DEPTH ? tree_max_depth : MAX_TREE_DEPTH) + 1;
+    if (grad_output.size(1) < n_levels) n_levels = grad_output.size(1);
+    for (int64_t i = 0; i < n_levels; ++i) {
+        if (data_ptr[i] == nullptr) continue;
+        for (int64_t j = 0; j < grad_output.size(2); ++j) {
+            // Atomic: different queries (tid) may accumulate into the same node.
+            atomicAdd(&data_ptr[i][j], grad_output[tid][i][j]);
+        }
+    }
+}
+
+template <typename scalar_t>
 __global__ void query_single_kernel(
         PackedTreeSpec<scalar_t> tree,
         const torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> indices,
@@ -181,6 +273,66 @@
     return grad_data;
 }
 
+// Backward pass of the vertical path query: accumulates grad_output into a
+// zero-initialised tensor of shape {M, N, N, N, K} (M, N taken from tree.child).
+torch::Tensor query_vertical_path_backward(TreeSpec& tree, torch::Tensor indices,
+                                           torch::Tensor grad_output) {
+    tree.check();
+    check_indices(indices);
+    DEVICE_GUARD(indices);
+    const int tree_max_depth = tree.max_depth.item<int>();  // kernel arrays hold MAX_TREE_DEPTH + 1 levels
+    TORCH_CHECK(tree_max_depth >= 0 && tree_max_depth <= MAX_TREE_DEPTH,
+                "tree max_depth ", tree_max_depth, " is outside [0, ", MAX_TREE_DEPTH, "]");
+    TORCH_CHECK(grad_output.dim() == 3 && grad_output.size(0) == indices.size(0) &&
+                grad_output.size(1) == tree_max_depth + 1,
+                "grad_output must have shape {Q, max_depth + 1, K}");
+    const auto Q = indices.size(0), N = tree.child.size(1),
+               K = grad_output.size(2), M = tree.child.size(0);
+    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS);
+    torch::Tensor grad_data = torch::zeros({M, N, N, N, K}, grad_output.options());
+    AT_DISPATCH_FLOATING_TYPES(indices.type(), __FUNCTION__, [&] {
+        device::query_path_single_kernel_backward<scalar_t><<<blocks, CUDA_N_THREADS>>>(
+                tree, indices.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+                grad_output.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+                grad_data.packed_accessor64<scalar_t, 5, torch::RestrictPtrTraits>(), tree_max_depth);
+    });
+    CUDA_CHECK_ERRORS;
+    return grad_data;
+}
+
+
+// Runs query_path_single_kernel over every row of `indices` and returns the
+// per-level values, shape {Q, max_depth + 1, K}, together with the per-level
+// node ids, shape {Q, max_depth + 1}, where K = tree.data.size(4).
+QueryResult query_vertical_path(TreeSpec& tree, torch::Tensor indices) {
+    tree.check();
+    check_indices(indices);
+    DEVICE_GUARD(indices);
+
+    const auto Q = indices.size(0), K = tree.data.size(4);
+    const int tree_max_depth = tree.max_depth.item<int>();  // Extract the value from the 1-size tensor
+    // The kernel keeps MAX_TREE_DEPTH + 1 per-level pointers in a fixed array.
+    TORCH_CHECK(tree_max_depth >= 0 && tree_max_depth <= MAX_TREE_DEPTH,
+                "tree max_depth ", tree_max_depth, " is outside [0, ", MAX_TREE_DEPTH, "]");
+    const int blocks = CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS);
+
+    // Zero-filled so that levels a query never reaches hold defined values.
+    torch::Tensor values = torch::zeros({Q, tree_max_depth + 1, K}, indices.options());
+    auto node_ids_options = at::TensorOptions().dtype(at::kLong)
+            .layout(tree.child.layout()).device(tree.child.device());
+    torch::Tensor node_ids = torch::empty({Q, tree_max_depth + 1}, node_ids_options);
+    AT_DISPATCH_FLOATING_TYPES(indices.type(), __FUNCTION__, [&] {
+        device::query_path_single_kernel<scalar_t><<<blocks, CUDA_N_THREADS>>>(
+                tree, indices.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
+                values.packed_accessor64<scalar_t, 3, torch::RestrictPtrTraits>(),
+                node_ids.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>(), tree_max_depth);
+    });
+    CUDA_CHECK_ERRORS;
+    return QueryResult(values, node_ids);
+}
+
+
+
 torch::Tensor calc_corners(
         TreeSpec& tree,
         torch::Tensor indexer) {
